import math
import torch
import torch.nn as nn
from torch.autograd import Variable

from onmt.Utils import aeq


class MultiHeadedAttention(nn.Module):
	"""
	Multi-Head Attention module from
	"Attention is All You Need"
	:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

	Similar to standard `dot` attention but uses
	multiple attention distributions simultaneously
	to select relevant items.

	.. mermaid::

	   graph BT
		  A[key]
		  B[value]
		  C[query]
		  O[output]
		  subgraph Attn
			D[Attn 1]
			E[Attn 2]
			F[Attn N]
		  end
		  A --> D
		  C --> D
		  A --> E
		  C --> E
		  A --> F
		  C --> F
		  D --> O
		  E --> O
		  F --> O
		  B --> O

	Besides the usual context vectors, `forward` can optionally return
	a second context computed over the projected keys (`return_key`)
	and the attention weights of every head (`all_attn`).

	Args:
	   head_count (int): number of parallel heads
	   model_dim (int): the dimension of keys/values/queries,
		   must be divisible by head_count (asserted here);
		   `forward` additionally asserts it is a multiple of 8
	   dropout (float): dropout probability applied to the
		   attention weights
	"""
	def __init__(self, head_count, model_dim, dropout=0.1):
		assert model_dim % head_count == 0, \
			"model_dim must be divisible by head_count"
		super(MultiHeadedAttention, self).__init__()

		self.head_count = head_count
		self.model_dim = model_dim
		self.dim_per_head = model_dim // head_count

		# One projection per role; each produces all heads at once and
		# is split into heads inside `forward`.
		self.linear_keys = nn.Linear(model_dim,
									 head_count * self.dim_per_head)
		self.linear_values = nn.Linear(model_dim,
									   head_count * self.dim_per_head)
		self.linear_query = nn.Linear(model_dim,
									  head_count * self.dim_per_head)
		self.sm = nn.Softmax(dim=-1)
		self.dropout = nn.Dropout(dropout)
		# Output projection for the value-based context.
		self.final_linear = nn.Linear(model_dim, model_dim)
		# Output projection for the key-based context (`return_key`).
		self.final_linear_2 = nn.Linear(model_dim, model_dim)

	def forward(self, key, value, query, mask=None, return_key=False, all_attn=False):
		"""
		Compute the context vector and the attention vectors.

		Args:
		   key (`FloatTensor`): set of `key_len`
				key vectors `[batch, key_len, dim]`
		   value (`FloatTensor`): set of `key_len`
				value vectors `[batch, key_len, dim]`
		   query (`FloatTensor`): set of `query_len`
				 query vectors  `[batch, query_len, dim]`
		   mask: binary mask `[batch, query_len, key_len]`; positions
				 where it is set have their score replaced by `-1e18`
				 before the softmax, i.e. they are excluded from
				 attention. The `query_len` axis may have size 1, in
				 which case it is broadcast over all queries.
		   return_key (bool): if True, also apply the (dropped-out)
				 attention weights to the projected keys and return
				 that second context alongside the first
		   all_attn (bool): if True, return the attention of every
				 head instead of only the first one

		Returns:
		   (output, attn) :

		   * output: context vectors `[batch, query_len, dim]`, or,
			 when `return_key` is True, the tuple
			 `(context, key_context)` of two such tensors
		   * attn: attention weights taken before dropout;
			 `[batch, query_len, key_len]` (head 0 only) by default,
			 `[batch, head_count, query_len, key_len]` when
			 `all_attn` is True

		Raises:
		   AssertionError: if the input shapes are inconsistent
		"""
		# Shape checks: key and value must agree on every axis; query
		# must share their batch size and feature dimension.
		batch, k_len, d = key.size()
		batch_, k_len_, d_ = value.size()
		aeq(batch, batch_)
		aeq(k_len, k_len_)
		aeq(d, d_)
		batch_, q_len, d_ = query.size()
		aeq(batch, batch_)
		aeq(d, d_)
		aeq(self.model_dim % 8, 0)
		if mask is not None:
			batch_, q_len_, k_len_ = mask.size()
			aeq(batch_, batch)
			aeq(k_len_, k_len)
			# `expand_as` below broadcasts a singleton query axis, so
			# only a non-singleton axis has to match exactly.
			if q_len_ != 1:
				aeq(q_len_, q_len)

		dim_per_head = self.dim_per_head
		head_count = self.head_count

		def shape(x):
			# [batch, len, heads * dim_per_head]
			#   -> [batch, heads, len, dim_per_head]
			return x.view(batch, -1, head_count, dim_per_head) \
				.transpose(1, 2)

		def unshape(x):
			# Inverse of `shape`: merge the heads back together.
			return x.transpose(1, 2).contiguous() \
				.view(batch, -1, head_count * dim_per_head)

		# 1) Project key, value, and query.
		key_up = shape(self.linear_keys(key))
		value_up = shape(self.linear_values(value))
		query_up = shape(self.linear_query(query))

		# 2) Calculate and scale scores.
		query_up = query_up / math.sqrt(dim_per_head)
		scores = torch.matmul(query_up, key_up.transpose(2, 3))

		if mask is not None:
			# Add a head axis and push masked scores towards -inf so
			# they vanish after the softmax.
			mask = mask.unsqueeze(1).expand_as(scores)
			scores = scores.masked_fill(Variable(mask), -1e18)

		# 3) Apply attention dropout and compute context vectors.
		attn = self.sm(scores)
		drop_attn = self.dropout(attn)
		context = unshape(torch.matmul(drop_attn, value_up))
		output = self.final_linear(context)

		# The output must have the same shape as the query.
		batch_, q_len_, d_ = output.size()
		aeq(q_len, q_len_)
		aeq(batch, batch_)
		aeq(d, d_)

		if return_key:
			# Same attention weights, applied to the projected keys.
			key_context = unshape(torch.matmul(drop_attn, key_up))
			key_context = self.final_linear_2(key_context)
			output = (output, key_context)

		# The returned attention is the pre-dropout distribution.
		attn = attn.view(batch, head_count, q_len, k_len)
		if all_attn:
			top_attn = attn
		else:
			# Return one attn (the first head).
			top_attn = attn[:, 0, :, :].contiguous()
		return output, top_attn
